import sys
import torch
import os.path as osp
import os
import warnings
from .base import BaseModel
from ..dataset import DATASET_TYPE
from ..smp import *
from PIL import Image

"""
    Please follow the instructions at the link below to download the checkpoint (ckpt) files.
    https://github.com/RBDash-Team/RBDash?tab=readme-ov-file#pretrained-weights
"""


class RBDash(BaseModel):
    """VLMEvalKit wrapper around the RBDash-v1.5 vision-language model.

    The RBDash code base is not pip-installable from here: the caller must
    clone it and pass its directory as ``root``. ``__init__`` downloads the
    required checkpoints below ``root`` and loads the model; ``generate_inner``
    answers a single prompt + image message.
    """

    INSTALL_REQ = True
    INTERLEAVE = False

    # (keyword, builder method name) pairs. A dataset whose lower-cased name
    # contains `keyword` gets its prompt from the named builder. The order is
    # significant: the first matching keyword wins.
    _PROMPT_BUILDERS = (
        ("mme", "build_mme"),
        ("hallusionbench", "build_hallusionbench"),
        ("mmmu", "build_mmmu"),
        ("mmbench", "build_mmbench"),
    )

    def __init__(self, model_path, root=None, conv_mode="qwen", **kwargs):
        """Download the checkpoints and load the RBDash model.

        Args:
            model_path: Hugging Face repo id; only ``"RBDash-Team/RBDash-v1.5"``
                is accepted.
            root: Path of the cloned RBDash repository. It is added to
                ``sys.path`` and used as the working directory while the
                weights are downloaded and loaded.
            conv_mode: Key into RBDash's ``conv_templates``.
            **kwargs: Overrides for the default generation settings
                (``temperature``, ``num_beams``, ``top_p``,
                ``max_new_tokens``, ``use_cache``).

        Raises:
            ValueError: If ``root`` is None.
            AssertionError: If ``model_path`` is not the supported checkpoint.
        """
        from huggingface_hub import snapshot_download

        if root is None:
            raise ValueError(
                "Please set `root` to RBDash code directory, "
                'which is cloned from here: "https://github.com/RBDash-Team/RBDash?tab=readme-ov-file" '
            )
        warnings.warn(
            "Please follow the instructions of RBDash to put the ckpt file in the right place, "
            "which can be found at https://github.com/RBDash-Team/RBDash?tab=readme-ov-file#structure"
        )
        assert (
            model_path == "RBDash-Team/RBDash-v1.5"
        ), "We only support RBDash-v1.5 for now"

        # Avoid growing sys.path with duplicates when several instances are built.
        if root not in sys.path:
            sys.path.append(root)
        try:
            from rbdash.model.builder import load_pretrained_model
            from rbdash.mm_utils import get_model_name_from_path
        except Exception:
            logging.critical(
                "Please first install RBdash and set the root path to use RBdash, "
                'which is cloned from here: "https://github.com/RBDash-Team/RBDash?tab=readme-ov-file" '
            )
            raise

        # The relative "./model_zoo/..." paths below are resolved against
        # `root`, so the working directory is switched temporarily. The
        # `finally` guarantees the caller's directory is restored even when a
        # download or the model load fails.
        previous_cwd = os.getcwd()
        os.chdir(root)
        try:
            warnings.warn(
                "Please set `root` to RBdash code directory, "
                'which is cloned from here: "https://github.com/RBDash-Team/RBDash?tab=readme-ov-file" '
            )
            try:
                model_name = get_model_name_from_path(model_path)
            except Exception:
                logging.critical(
                    "Please follow the instructions of RBdash to put the ckpt file in the right place, "
                    "which can be found at https://github.com/RBDash-Team/RBDash?tab=readme-ov-file#structure"
                )
                raise

            download_model_path = snapshot_download(model_path)

            internvit_local_dir = "./model_zoo/OpenGVLab/InternViT-6B-448px-V1-5"
            os.makedirs(internvit_local_dir, exist_ok=True)
            snapshot_download(
                "OpenGVLab/InternViT-6B-448px-V1-5", local_dir=internvit_local_dir
            )

            convnext_local_dir = "./model_zoo/OpenAI/openclip-convnext-large-d-320-laion2B-s29B-b131K-ft-soup"
            os.makedirs(convnext_local_dir, exist_ok=True)
            snapshot_download(
                "laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup",
                local_dir=convnext_local_dir,
            )
            # Use the `/resolve/` endpoint: a `/blob/` URL serves the HTML
            # file-viewer page rather than the raw JSON file.
            preprocessor_url = "https://huggingface.co/openai/clip-vit-large-patch14-336/resolve/main/preprocessor_config.json"
            download_file_path = osp.join(
                convnext_local_dir, "preprocessor_config.json"
            )
            if not osp.exists(download_file_path):
                print(f"download preprocessor to {download_file_path}")
                download_file(preprocessor_url, download_file_path)

            tokenizer, model, image_processor, image_processor_aux, context_len = (
                load_pretrained_model(
                    download_model_path, None, model_name, device_map="auto"
                )
            )
        finally:
            os.chdir(previous_cwd)

        self.model = model
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.image_processor_aux = image_processor_aux
        self.conv_mode = conv_mode

        # Reuse the unknown token as padding, creating one if the tokenizer has none.
        if tokenizer.unk_token is None:
            tokenizer.unk_token = "<|endoftext|>"
        tokenizer.pad_token = tokenizer.unk_token

        # `max_new_tokens` defaults to 512, the value generation previously
        # hard-coded, so default outputs are unchanged while caller overrides
        # are now honoured.
        kwargs_default = dict(
            temperature=float(0.2),
            num_beams=1,
            top_p=None,
            max_new_tokens=512,
            use_cache=True,
        )
        kwargs_default.update(kwargs)
        self.kwargs = kwargs_default

    def generate_inner(self, message, dataset=None):
        """Generate the model's answer for one prompt + image message.

        Args:
            message: Input forwarded to ``self.message_to_promptimg``, which
                yields the prompt text and an image path.
            dataset: Optional dataset name, forwarded unchanged.

        Returns:
            The decoded, whitespace-stripped text of the first generated sequence.
        """
        try:
            from rbdash.constants import (
                IMAGE_TOKEN_INDEX,
                DEFAULT_IMAGE_TOKEN,
                DEFAULT_IM_START_TOKEN,
                DEFAULT_IM_END_TOKEN,
            )
            from rbdash.conversation import conv_templates
            from rbdash.mm_utils import tokenizer_image_token, process_images
        except Exception:
            logging.critical(
                "Please first install RBdash and set the root path to use RBdash, "
                'which is cloned from here: "https://github.com/RBDash-Team/RBDash?tab=readme-ov-file" '
            )
            raise

        prompt, image = self.message_to_promptimg(message, dataset=dataset)
        image = Image.open(image).convert("RGB")

        # Prefix the question with the image placeholder token(s).
        if self.model.config.mm_use_im_start_end:
            image_prefix = (
                DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
            )
        else:
            image_prefix = DEFAULT_IMAGE_TOKEN
        prompt = image_prefix + "\n" + prompt

        # Wrap the question in the conversation template, leaving the
        # second role's turn empty for the model to complete.
        conv = conv_templates[self.conv_mode].copy()
        conv.append_message(conv.roles[0], prompt)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        input_ids = tokenizer_image_token(
            prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
        )
        input_ids = input_ids.unsqueeze(0).cuda()

        image_tensor, image_tensor_aux = self._prepare_image_tensors(
            image, process_images
        )

        images = image_tensor[None].to(
            dtype=self.model.dtype, device="cuda", non_blocking=True
        )
        if len(image_tensor_aux) > 0:
            images_aux = image_tensor_aux[None].to(
                dtype=self.model.dtype, device="cuda", non_blocking=True
            )
        else:
            images_aux = None

        temperature = self.kwargs["temperature"]
        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=images,
                images_aux=images_aux,
                do_sample=temperature > 0,
                temperature=temperature,
                top_p=self.kwargs["top_p"],
                num_beams=self.kwargs["num_beams"],
                max_new_tokens=self.kwargs["max_new_tokens"],
                use_cache=self.kwargs["use_cache"],
            )

        decoded = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        return decoded[0].strip()

    def _sync_processors_to_aux_size(self, aux_size):
        """Make both image processors crop/resize to ``aux_size``.

        The main processor's original ``crop_size`` is saved once as
        ``image_size_raw`` before being overwritten.
        """
        if not hasattr(self.image_processor, "image_size_raw"):
            self.image_processor.image_size_raw = (
                self.image_processor.crop_size.copy()
            )
        for processor in (self.image_processor, self.image_processor_aux):
            processor.crop_size["height"] = aux_size
            processor.crop_size["width"] = aux_size
            processor.size["shortest_edge"] = aux_size

    def _prepare_image_tensors(self, image, process_images):
        """Preprocess ``image`` into the main and auxiliary image tensors.

        Args:
            image: The RGB ``PIL.Image`` to encode.
            process_images: RBDash's ``mm_utils.process_images`` function.

        Returns:
            ``(image_tensor, image_tensor_aux)``. ``image_tensor_aux`` is an
            empty list when the model config has no ``image_size_aux``.
        """
        config = self.model.config
        has_aux = hasattr(config, "image_size_aux")
        if has_aux:
            self._sync_processors_to_aux_size(config.image_size_aux)

        image_tensor = process_images([image], self.image_processor, config)[0]
        image_grid = getattr(config, "image_grid", 1)

        if has_aux:
            raw_size = self.image_processor.image_size_raw
            raw_shape = [
                raw_size["height"] * image_grid,
                raw_size["width"] * image_grid,
            ]
            if self.image_processor is not self.image_processor_aux:
                image_tensor_aux = process_images(
                    [image], self.image_processor_aux, config
                )[0]
            else:
                image_tensor_aux = image_tensor
            # The image was processed at the aux size; resize it to the
            # saved raw size times the grid factor for the main input.
            image_tensor = torch.nn.functional.interpolate(
                image_tensor[None], size=raw_shape, mode="bilinear", align_corners=False
            )[0]
        else:
            image_tensor_aux = []

        if image_grid >= 2:
            image_tensor = self._split_into_grid(image_tensor, image_grid)
        return image_tensor, image_tensor_aux

    def _split_into_grid(self, image_tensor, image_grid):
        """Cut ``image_tensor`` into ``image_grid * image_grid`` tiles.

        Each tile has the saved raw height and width. When the config sets
        ``image_global``, a copy of the full image resized to that size is
        appended after the tiles.
        """
        raw_size = self.image_processor.image_size_raw
        height, width = raw_size["height"], raw_size["width"]

        # (3, grid*H, grid*W) -> (3, grid, H, grid, W) -> (grid*grid, 3, H, W)
        tiles = image_tensor.reshape(3, image_grid, height, image_grid, width)
        tiles = tiles.permute(1, 3, 0, 2, 4)
        tiles = tiles.reshape(-1, 3, height, width)

        if getattr(self.model.config, "image_global", False):
            global_image = image_tensor
            if len(global_image.shape) == 3:
                global_image = global_image[None]
            global_image = torch.nn.functional.interpolate(
                global_image,
                size=[height, width],
                mode="bilinear",
                align_corners=False,
            )
            tiles = torch.cat([tiles, global_image], dim=0)
        return tiles.contiguous()

    def use_custom_prompt(self, dataset):
        """Return True when ``dataset`` gets a prompt from ``build_prompt``.

        Args:
            dataset: Dataset name; must not be None.
        """
        assert dataset is not None
        if listinstr(["MMDU", "MME-RealWorld", "MME-RealWorld-CN"], dataset):
            # For Multi-Turn we don't have custom prompt
            return False
        name = dataset.lower()
        return any(keyword in name for keyword, _ in self._PROMPT_BUILDERS)

    def build_mme(self, line):
        """Build the MME prompt from ``line["question"]``."""
        # NOTE(review): no separator is inserted between the question and the
        # instruction here, unlike build_hallusionbench; confirm this is
        # intended before changing it, as it affects the benchmark prompt.
        return line["question"] + "Answer the question using a single word or phrase."

    def build_hallusionbench(self, line):
        """Build the HallusionBench prompt from ``line["question"]``."""
        return line["question"] + "\nAnswer the question using a single word or phrase."

    @staticmethod
    def _build_choice_prompt(line, option_template):
        """Build a hint/question/options prompt shared by MMBench and MMMU.

        Args:
            line: Record with a ``"question"`` entry, optionally ``"hint"``
                and option entries keyed by uppercase letters. Entries for
                which ``pd.isna`` is true are treated as absent.
            option_template: ``str.format`` template with ``{key}`` and
                ``{item}`` fields that renders one option, without the
                trailing newline.

        Returns:
            The prompt text. It ends with a letter-choice instruction when
            options exist, otherwise with a short-answer instruction.
        """
        question = line["question"]
        options = {
            cand: line[cand]
            for cand in string.ascii_uppercase
            if cand in line and not pd.isna(line[cand])
        }
        hint = line["hint"] if ("hint" in line and not pd.isna(line["hint"])) else None

        parts = []
        if hint is not None:
            parts.append(f"Hint: {hint}\n")
        parts.append(f"Question: {question}\n")
        if options:
            parts.append("Options:\n")
            parts.extend(
                option_template.format(key=key, item=item) + "\n"
                for key, item in options.items()
            )
            parts.append(
                "Answer with the option's letter from the given choices directly."
            )
        else:
            parts.append("Answer the question using a single word or phrase.")
        return "".join(parts)

    def build_mmbench(self, line):
        """Build the MMBench prompt; options are rendered as ``A. text``."""
        return self._build_choice_prompt(line, "{key}. {item}")

    def build_mmmu(self, line):
        """Build the MMMU prompt; options are rendered as ``(A) text``."""
        return self._build_choice_prompt(line, "({key}) {item}")

    def build_prompt(self, line, dataset=None):
        """Build the message (prompt text followed by images) for ``line``.

        Args:
            line: Dataset record passed to ``self.dump_image`` and to the
                dataset-specific prompt builder.
            dataset: Dataset name; ``use_custom_prompt(dataset)`` must be True.

        Returns:
            A list with one ``type="text"`` dict followed by one
            ``type="image"`` dict per path returned by ``self.dump_image``.
        """
        assert dataset is None or isinstance(dataset, str)
        assert self.use_custom_prompt(dataset)
        tgt_path = self.dump_image(line, dataset)

        name = dataset.lower()
        for keyword, builder_name in self._PROMPT_BUILDERS:
            if keyword in name:
                prompt = getattr(self, builder_name)(line)
                break

        ret = [dict(type="text", value=prompt)]
        ret.extend(dict(type="image", value=s) for s in tgt_path)
        return ret
